# -*- coding: utf-8 -*-
"""
Created on Fri Sep 17 10:29:29 2021
A DQN algorithm, written with TF2, that drives an arbitrary single-qubit state to the target state |1>
@author: Waikikilick
"""

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' #设置 tf 的报警等级
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,optimizers,losses,initializers,Sequential,metrics,models
import copy 
from collections import deque
import random
from scipy.linalg import expm
from time import *

tf.random.set_seed(1) # pin TensorFlow randomness so runs are reproducible
np.random.seed(1)     # pin NumPy randomness (action exploration)
random.seed (1)       # pin stdlib randomness (replay sampling, shuffling)

class Agent(object):
    """DQN agent with experience replay and a periodically-synced target net.

    The online network (`self.model`) and target network (`self.target_model`)
    are attached from the outside (built via `create_model()` or loaded from
    disk) before training starts.
    """

    def __init__(self, 
            n_actions=4,
            n_features=4,
            learning_rate=0.0001,
            reward_decay=0.9,
            e_greedy=0.99,
            replace_target_iter=250,
            memory_size=2000,
            batch_size=32,
            e_greedy_increment=None):
        
        self.n_actions = n_actions                      # size of the discrete action space
        self.n_features = n_features                    # length of the state vector
        self.lr = learning_rate
        self.gamma = reward_decay                       # discount factor in the Q target
        self.epsilon_max = e_greedy                     # ceiling of the greedy probability
        self.replace_target_iter = replace_target_iter  # learn() calls between target-net syncs
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.epsilon_increment = e_greedy_increment
        # Start fully exploratory when annealing is requested, otherwise fixed greedy.
        self.epsilon = 0 if self.epsilon_increment is not None else self.epsilon_max
        self.learn_step_counter = 0
        self.memory = deque(maxlen=self.memory_size)    # replay buffer of (s, a, r, s_)
        self.memory_counter = 0                         # total transitions ever stored
        
    def create_model(self): # build one Q-network mapping a state to per-action Q values
    
        model = tf.keras.Sequential()
        model.add(tf.keras.layers.Dense(32, input_dim=self.n_features, activation='relu',kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.05), bias_initializer=initializers.Constant(value=0.1)))
        model.add(tf.keras.layers.Dense(32, activation='relu',kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.05), bias_initializer=initializers.Constant(value=0.1)))
        # NOTE(review): 'relu' on the OUTPUT layer clamps Q estimates to >= 0.
        # A linear output is the usual DQN choice; this only works here because
        # rewards (fidelities) are non-negative.  agent_matrix() mirrors this
        # relu-on-output behaviour, so change both together if at all.
        model.add(tf.keras.layers.Dense(self.n_actions, activation='relu',kernel_initializer=initializers.RandomNormal(mean=0.0, stddev=0.05), bias_initializer=initializers.Constant(value=0.1)))
        
        model.compile(loss='mse',optimizer=tf.keras.optimizers.RMSprop(learning_rate=self.lr)) # MSE loss, RMSprop optimizer
        return model
        
    def save_models(self): # persist both networks in SavedModel format
    
        self.model.save('dqn_tf2_single_qubit_to_1_saved_model')
        self.target_model.save('dqn_tf2_single_qubit_to_1_saved_target_model')
        print('models saved !')
           
    def store_transition(self, s, a, r, s_): # append one (s, a, r, s') transition to the buffer
    
        self.memory.append((s, a, r, s_)) 
        self.memory_counter += 1

    def choose_action(self, observation, tanxin): # `tanxin` (greediness) selects the policy
      # tanxin = 0 : pick a uniformly random action
      # tanxin = 1 : always pick the network's argmax action (used for testing)
      # any other value (e.g. 0.5): epsilon-greedy using the current self.epsilon
        
        if tanxin == 0: # fully random
            action = np.random.randint(0, self.n_actions)
        elif tanxin == 1: # fully greedy on the online network
            # NOTE(review): reshape hardcodes length 4 instead of self.n_features
            demo = tf.reshape(observation,(1,4))
            action = np.argmax(self.model.predict(demo)[0])
        else:
            if np.random.uniform() < self.epsilon:
                demo = tf.reshape(observation,(1,4))
                action = np.argmax(self.model.predict(demo)[0])
            else:
                action = np.random.randint(0, self.n_actions)
            
        return action        
        
        
    def learn(self):
        
        if self.learn_step_counter % self.replace_target_iter == 0: # periodically sync the target net
            self.target_model.set_weights(self.model.get_weights()) # hard copy of the online weights
                    
        if self.memory_counter < self.batch_size: # wait until the buffer can fill one batch
            return
                
        batch_memory = random.sample(self.memory, self.batch_size) # uniform replay sampling
        s_batch = np.array([replay[0] for replay in batch_memory])
        demo_s_batch = tf.reshape(s_batch,(self.batch_size,4))
        Q = self.model.predict(demo_s_batch)
        
        next_s_batch = np.array([replay[3] for replay in batch_memory])
        demo_next_s_batch = tf.reshape(next_s_batch,(self.batch_size,4))
        Q_next = self.target_model.predict(demo_next_s_batch)

        # Overwrite the taken action's Q value with the bootstrap target.
        # NOTE(review): transitions carry no terminal flag, so terminal states
        # are bootstrapped like any other -- confirm this is intended.
        for i, replay in enumerate(batch_memory):
            _, a, reward, _ = replay
            Q[i][a] =   reward + self.gamma * np.amax(Q_next[i]) # r + gamma * max_a' Q_target(s', a')
 
        history = self.model.fit(demo_s_batch, Q, verbose=0) # one gradient step toward the corrected targets
            
        # Anneal epsilon upward toward epsilon_max (assumes epsilon_increment is set).
        self.epsilon = self.epsilon + self.epsilon_increment if self.epsilon < self.epsilon_max else self.epsilon_max
        self.learn_step_counter += 1
      

class env(object):
    """Single-qubit control environment.

    The RL state is the 4-vector [Re(c0), Re(c1), Im(c0), Im(c1)] of the
    current wavefunction amplitudes.  Each action chooses the sigma_z
    coefficient of the control Hamiltonian for one slice of duration `dt`;
    the reward is the fidelity with the target state |1>.
    """

    def  __init__(self, 
        action_space = [0,1], # allowed control values (defaults only; real values are passed in by the caller)
        dt = 0.1,             # duration of one control slice
        noise_a = 0,          # amplitude applied to the frozen noise samples below
        ): 
        # NOTE(review): mutable default `action_space=[0,1]` is shared across
        # instantiations; harmless as long as it is never mutated.
        self.action_space = action_space
        self.n_actions = len(self.action_space)
        self.n_features = 4 # length of the real-valued state description
        self.target_psi =  np.mat([[0], [1]], dtype=complex) # target state |1>
        self.s_x = np.mat([[0, 1], [1, 0]], dtype=complex)   # Pauli sigma_x
        self.s_z = np.mat([[1, 0], [0, -1]], dtype=complex)  # Pauli sigma_z
        self.dt = dt
        self.training_set, self.validation_set, self.testing_set = self.psi_set()
        
        self.noise_a = noise_a
        self.noise_normal = np.array([ 1.62434536, -0.61175641, -0.52817175, -1.07296862,  0.86540763,
       -2.3015387 ,  1.74481176, -0.7612069 ,  0.3190391 , -0.24937038,
        1.46210794, -2.06014071, -0.3224172 , -0.38405435,  1.13376944,
       -1.09989127, -0.17242821, -0.87785842,  0.04221375,  0.58281521,
       -1.10061918,  1.14472371,  0.90159072,  0.50249434,  0.90085595,
       -0.68372786, -0.12289023, -0.93576943, -0.26788808,  0.53035547,
       -0.69166075, -0.39675353, -0.6871727 , -0.84520564, -0.67124613,
       -0.0126646 , -1.11731035,  0.2344157 ,  1.65980218,  0.74204416,
       -0.19183555])
        # noise_normal: 41 pre-drawn standard-normal samples, generated once via
        # np.random.seed(1); np.random.normal(loc=0.0, scale=1.0, size=41)
        # and frozen here so the noise is reproducible across runs.
        self.noise = self.noise_a * self.noise_normal
        
        
    def psi_set(self):
        """Sample states on the Bloch sphere and split them into
        training / validation / testing sets (32 / 32 / 64 points).
        """
        theta_num = 6 # number of theta values, excluding the two poles
        varphi_num = 21 # number of varphi values around each circle
        # total points: theta_num * varphi_num + 2 poles = 6 * 21 + 2 = 128
        
        theta = np.delete(np.linspace(0,np.pi,theta_num+1,endpoint=False),[0])
        varphi = np.linspace(0,np.pi*2,varphi_num,endpoint=False) 
        
        psi_set = []
        for ii in range(theta_num):
            for jj in range(varphi_num):
                psi_set.append(np.mat([[np.cos(theta[ii]/2)],[np.sin(theta[ii]/2)*(np.cos(varphi[jj])+np.sin(varphi[jj])*(0+1j))]]))
        psi_set.append(np.mat([[1], [0]], dtype=complex)) # north pole |0>
        psi_set.append(np.mat([[0], [1]], dtype=complex)) # south pole |1>
        random.shuffle(psi_set) # shuffle before splitting
    
        training_set = psi_set[0:32]
        validation_set = psi_set[32:64]
        testing_set = psi_set[64:128]
        
        return training_set, validation_set, testing_set
        
    
    def reset(self, init_psi): # start a new episode from the chosen initial state
        """Convert a 2x1 complex state vector into the 4-real-number RL state."""
        init_state = np.array([init_psi[0,0].real, init_psi[1,0].real, init_psi[0,0].imag, init_psi[1,0].imag])
        # layout: [Re(c0), Re(c1), Im(c0), Im(c1)]
        return init_state
    
    
    def step(self, state, action, nstep):
        """Evolve the state one slice `dt` under H = action * sigma_z + sigma_x.

        Returns (next_state, reward, done, fidelity).  The reward is the raw
        fidelity with the target state; the episode terminates once the
        infidelity drops below the tolerance or after 2*pi/dt steps.
        """
        # Rebuild the complex wavefunction from the real 4-vector:
        # first half + 1j * second half -> [[c0, c1]].
        psi = np.array([state[0:int(len(state) / 2)] + state[int(len(state) / 2):int(len(state))] * 1j])
        
        psi = psi.T
        # column vector [[c0],
        #                [c1]]

        psi = np.mat(psi) 
        # as np.matrix so `*` is a true matrix product below
        
        H =  float(action)* self.s_z + 1 * self.s_x # control Hamiltonian for this slice
        U = expm(-1j * H * self.dt) # exact propagator over dt
        psi = U * psi  # next state

        err = 10e-4 # infidelity tolerance (10e-4 == 1e-3)
        fid = (np.abs(psi.H * self.target_psi) ** 2).item(0).real  # |<psi|target>|^2
        rwd = (fid) # reward is the fidelity itself
        
        done = (((1 - fid) < err) or nstep >= 2 * np.pi / self.dt) # success or step limit

        # Flatten back to the real-valued form: the network cannot take complex
        # inputs, since gradient-based training would not propagate through them.
    
        psi = np.array(psi)
        psi_T = psi.T
        state = np.array(psi_T.real.tolist()[0] + psi_T.imag.tolist()[0]) 

        return state, rwd, done, fid   
    
def relu(vector):
    """Element-wise rectified linear unit: max(0, x).

    Fixes the original implementation, which looped over rows in Python and
    mutated its argument in place; this version is a single vectorized call
    and leaves the input untouched (every call site rebinds the return value,
    so behavior at the callers is unchanged).
    """
    return np.maximum(0, vector)

def agent_matrix(x):
    """Greedy action from the module-level weight list `net`.

    Replays the Q-network forward pass with plain matrix algebra: `net`
    holds alternating (kernel, bias) tensors for the three Dense layers,
    and a ReLU follows every layer (including the last, matching the
    network's output activation).  Returns the argmax action index.
    """
    activation = np.mat(x)
    for kernel, bias in zip(net[0::2], net[1::2]):
        activation = relu(activation * kernel.numpy() + bias.numpy())
    return np.argmax(activation)

def training(ep_max): 
    """Run `ep_max` training episodes against the module-level env/agent.

    Each episode starts from a random state of the training set and follows
    the dynamic epsilon-greedy policy (tanxin = 0.5) until the environment
    signals termination; every transition is stored and triggers one
    learning step.
    """
    training_set = env.training_set
    validation_set = env.validation_set  # reserved for the (disabled) validation pass below
    
    for episode in range(ep_max):
        initial_psi = random.choice(training_set)
        best_fid = 0
        print('--------------------------')
        print('训练中..., 当前回合为:', episode)
        state = env.reset(initial_psi)
        step_count = 0
        done = False
        
        while not done:
            act = agent.choose_action(state, 0.5)
            next_state, rwd, done, fid = env.step(state, act, step_count)
            step_count += 1
            best_fid = max(best_fid, fid)  # tracked for inspection only
            agent.store_transition(state, act, rwd, next_state)
            agent.learn()
            state = next_state

        # A commented-out block used to live here: every few episodes it ran
        # the pure-network policy (agent_matrix) over validation_set and
        # appended the mean max-fidelity / discounted reward to
        # validation_fid_history / validation_reward_history.

def testing(): # evaluate the greedy policy on the held-out testing set
    """Roll out the pure-network policy (agent_matrix) from every state in
    the testing set and return the list of per-episode maximum fidelities.
    """
    print('\n测试中, 请稍等...')
    
    best_fids = []
    for init_psi in env.testing_set:
        state = env.reset(init_psi)
        step_count = 0
        episode_best = 0
        done = False
        
        while not done:
            act = agent_matrix(state)
            state, _, done, fid = env.step(state, act, step_count)
            step_count += 1
            episode_best = max(episode_best, fid)
        
        best_fids.append(episode_best)
        
    return best_fids
                 
    
if __name__ == "__main__":
    
    dt = np.pi/10  # duration of one control slice
    
    # 4 allowed control amplitudes: 0 .. 3
    env = env(action_space = list(range(4)),
               dt = dt)
    
    agent = Agent(env.n_actions, env.n_features,
              learning_rate = 0.01,
              reward_decay = 0.9, 
              e_greedy = 0.95,
              replace_target_iter = 200,
              memory_size = 20000,
              e_greedy_increment = 0.001)
    
    # Reuse previously trained networks when they exist in the working
    # directory, otherwise build fresh ones.
    # BUGFIX: the check used to look for
    # 'dqn_tf2_single_qubit_to_1_saved_model_' (trailing underscore), which
    # can never match the directory name written by Agent.save_models(), so
    # saved models were silently ignored and always rebuilt.
    files_list = os.listdir(os.getcwd())
    if 'dqn_tf2_single_qubit_to_1_saved_model' in files_list:
        agent.model = keras.models.load_model('dqn_tf2_single_qubit_to_1_saved_model')
        agent.target_model = keras.models.load_model('dqn_tf2_single_qubit_to_1_saved_target_model')
    else:
        agent.model = agent.create_model()
        agent.target_model = agent.create_model()
        
    # Snapshot of the trainable weights, consumed by agent_matrix() for the
    # hand-rolled forward pass during testing.
    network = agent.model
    net = network.trainable_variables
        
    validation_reward_history = []
    validation_fid_history = []
        
    begin_training = time()

    # Training.  This block can be re-run (e.g. pasted into a console) to
    # continue training; choose ep_max = x*5 - 1 on the first run and
    # x*5 + 1 when resuming so the (disabled) validation bookkeeping in
    # training() stays aligned.
    training(ep_max = 34)   
   
    end_training = time()
    
    training_time = end_training - begin_training
    
    print('\ntraining_time =',training_time)  
        
    print('各验证回合回合奖励记录为: ', validation_reward_history)
    print('各验证回合最大保真度记录为: ', validation_fid_history)

    # Testing
    testing_fid_list = testing()
    print('测试集平均保真度为：', np.mean(testing_fid_list))
    
# # 继续训练
# training(ep_max = )
# print('各验证回合回合奖励记录为: ', validation_reward_history)
# print('各验证回合最大保真度记录为: ', validation_fid_history)

# # 保存模型
# agent.save_models() 

# training_time = 97.0888741016388
# 各验证回合回合奖励记录为:  6.609857110732836
# [3.671442271916244, 3.850062953473093, 4.187260274317825, 3.3456087960519008, 3.3497329323278793, 
#   3.4737030272344125, 4.442147763339376, 4.510979608437309, 4.9756290018785645, 4.597720807954386, 
#   4.32006177187051, 4.810577566200527, 5.47946376592431, 6.280617112148929, 5.784472285046444, 
#   6.2582916134831175, 6.429794070682728, 5.872448223627309, 6.28630975895328, 6.202109197318154, 
#   6.209682439136613, 6.29487715115798, 6.519589972911713, 6.490521340265627, 6.398132427284532, 
#   6.260040494701903, 6.575289460434888, 6.131580670275586, 6.484223840715175, 6.256010320028766, 
#   6.243330286583399, 6.272477546889264, 6.253635310669641, 6.315851191709553, 6.5561936455373, 
#   6.379391080159618, 6.446132291253262, 6.324236901632354, 6.231961404922092, 6.361992549807534, 
#   6.177763820256738, 6.522177358734883, 6.513836103102932, 6.301719431328918, 6.327570123760817, 
#   6.199476932904389, 6.264313696912857, 6.277468159377481, 6.319113340304469, 6.241714565769017, 
#   6.524833928971624, 6.360365498990752, 6.609857110732836, 6.455907779738186, 6.082080695835733, 
#   6.452652880984957, 6.191714933048511, 6.253996673220078, 5.951512019131693, 6.2359184343719765, 
#   6.078475640196686, 6.13241547333566, 6.355873884585611, 6.301567848148076, 6.365465529398708, 
#   6.294674621464786, 6.467689752764026, 6.456312417263589, 6.382834446216963, 6.3641221476712335, 
#   6.215522672961125, 6.024187150697035, 6.072202567193026, 6.026934930319702, 6.136169393884402, 
#   6.301260332876584, 6.468935513243963, 6.294406849986691, 6.355666098538954, 6.362096614552028, 
#   6.341447247652647, 6.461928144863824, 6.157440008251741, 5.948586830480801, 6.348209497211318, 
#   6.011764662409822, 6.109205835435826, 6.365735926344174, 6.104188873868674, 6.30308202855311, 
#   6.079500974434484, 6.136347127877265, 6.4074859896712635, 6.198175494222809, 6.342605222040446, 
#   6.338944081458296, 6.221253622400592, 6.277382805322902, 6.558651522892456, 6.1146665421331505, 5.949954621718522]
# 各验证回合最大保真度记录为:  epi = 35  0.9971002193091636
# [0.6257186288970968, 0.9169130597622623, 0.8822339268818357, 0.8955381021812495, 0.6819964589843895, 
#   0.8994964197349602, 0.9322561831943647, 0.9490143536587328, 0.9752221572199954, 0.9525965964971024, 
#   0.8702586034515185, 0.8759363864526145, 0.9921228134580831, 0.994928868175181, 0.9939210396553269, 
#   0.9940126618414441, 0.9946364335153349, 0.992963230476606,  0.989612302434834, 0.9952415425585655, 
#   0.9949496241321334, 0.995348307147061,  0.9921380631727996, 0.9945542857655739, 0.990660773087622, 
#   0.9955109109413389, 0.9903853532629843, 0.9968705822951968, 0.9946772646823828, 0.9961913194711288, 
#   0.9952182585068149, 0.9938845241420311, 0.9954090934469306, 0.9971002193091636,0.991143878307036, 
#   0.995599545747836,  0.995730233525028,  0.9912078526054555, 0.9942641434631454, 0.9933297125192349, 
#   0.9957769529622655, 0.9940011503116974, 0.9947627913858689, 0.9933190008168853, 0.9874175555170868, 
#   0.9945082717598486, 0.9953580111535851, 0.9958098597369532, 0.9893029605383845, 0.9859880838418273, 
#   0.9949689983236156, 0.9947553906969754, 0.9860151794810974, 0.9942089090726325, 0.9962009289927006, 
#   0.9925433351735327, 0.996097652820696,  0.9962632871567698, 0.9948448646373154, 0.9866259729616449, 
#   0.9926696355326567, 0.9946844262187495, 0.9917668197199095, 0.9933665247643723, 0.9864351169687412, 
#   0.99300371383812,   0.9946010986924075, 0.9897324680868488, 0.9942446056878449, 0.9943012576003286, 
#   0.9957108900620616, 0.991129325022439,  0.9907783347984178, 0.9962563574218085, 0.9957097884111015, 
#   0.9959983596225237, 0.992585640095416,  0.9922326232977073, 0.9926927845388902, 0.9942658315939525, 
#   0.9942098254753746, 0.9937255265944402, 0.9946135423111429, 0.9970788934518674, 0.9960928420400659, 
#   0.9949280220044565, 0.9950101421898061, 0.9918733955810777, 0.9948669631642014, 0.9937179191557958, 
#   0.995742885995716,  0.9872099197127941, 0.9891815300172462, 0.9949062508409448, 0.9901019462754195, 
#   0.9939807153608645, 0.992631836775088,  0.9935249521579705, 0.9952665784477318, 0.9933575770615963, 0.9959786310865757]

# rewards: 3.67144227 3.85006295 4.18726027 3.3456088  3.34973293 3.47370303 4.44214776 4.51097961 4.975629   4.59772081 4.32006177 4.81057757 5.47946377 6.28061711 5.78447229 6.25829161 6.42979407 5.87244822 6.28630976 6.2021092  6.20968244 6.29487715 6.51958997 6.49052134 6.39813243 6.26004049 6.57528946 6.13158067 6.48422384 6.25601032 6.24333029 6.27247755 6.25363531 6.31585119 6.55619365 6.37939108 6.44613229 6.3242369  6.2319614  6.36199255 6.17776382 6.52217736 6.5138361  6.30171943 6.32757012 6.19947693 6.2643137  6.27746816 6.31911334 6.24171457 6.52483393 6.3603655  6.60985711 6.45590778 6.0820807  6.45265288 6.19171493 6.25399667 5.95151202 6.23591843 6.07847564 6.13241547 6.35587388 6.30156785 6.36546553 6.29467462 6.46768975 6.45631242 6.38283445 6.36412215 6.21552267 6.02418715 6.07220257 6.02693493 6.13616939 6.30126033 6.46893551 6.29440685 6.3556661  6.36209661 6.34144725 6.46192814 6.15744001 5.94858683 6.3482095  6.01176466 6.10920584 6.36573593 6.10418887 6.30308203 6.07950097 6.13634713 6.40748599 6.19817549 6.34260522 6.33894408 6.22125362 6.27738281 6.55865152 6.11466654 5.94995462
# fids: 0.62571863 0.91691306 0.88223393 0.8955381  0.68199646 0.89949642 0.93225618 0.94901435 0.97522216 0.9525966  0.8702586  0.87593639 0.99212281 0.99492887 0.99392104 0.99401266 0.99463643 0.99296323 0.9896123  0.99524154 0.99494962 0.99534831 0.99213806 0.99455429 0.99066077 0.99551091 0.99038535 0.99687058 0.99467726 0.99619132 0.99521826 0.99388452 0.99540909 0.99710022 0.99114388 0.99559955 0.99573023 0.99120785 0.99426414 0.99332971 0.99577695 0.99400115 0.99476279 0.993319   0.98741756 0.99450827 0.99535801 0.99580986 0.98930296 0.98598808 0.994969   0.99475539 0.98601518 0.99420891 0.99620093 0.99254334 0.99609765 0.99626329 0.99484486 0.98662597 0.99266964 0.99468443 0.99176682 0.99336652 0.98643512 0.99300371 0.9946011  0.98973247 0.99424461 0.99430126 0.99571089 0.99112933 0.99077833 0.99625636 0.99570979 0.99599836 0.99258564 0.99223262 0.99269278 0.99426583 0.99420983 0.99372553 0.99461354 0.99707889 0.99609284 0.99492802 0.99501014 0.9918734  0.99486696 0.99371792 0.99574289 0.98720992 0.98918153 0.99490625 0.99010195 0.99398072 0.99263184 0.99352495 0.99526658 0.99335758 0.99597863

# 测试集最大平均保真度为： 0.9955485434320497
